﻿//***************************************************************************************
// Ssao.cpp by Frank Luna (C) 2011 All Rights Reserved.
//***************************************************************************************

#include "Ssao.h"
#include <DirectXPackedVector.h>
#include <graphics/imath.h>

using namespace DirectX;
using namespace DirectX::PackedVector;
using namespace Microsoft::WRL;

namespace ifire::dx {

Ssao::Ssao(ID3D12Device* device, ID3D12GraphicsCommandList* cmdList, UINT width,
    UINT height)

{
  md3dDevice = device;

  OnResize(width, height);

  BuildOffsetVectors();
  BuildRandomVectorTexture(cmdList);
}

UINT Ssao::SsaoMapWidth() const { return mRenderTargetWidth / 2; }

UINT Ssao::SsaoMapHeight() const { return mRenderTargetHeight / 2; }

void Ssao::GetOffsetVectors(DirectX::XMFLOAT4 offsets[14]) {
  std::copy(&mOffsets[0], &mOffsets[14], &offsets[0]);
}

eastl::vector<float> Ssao::CalcGaussWeights(float sigma) {
  float twoSigma2 = 2.0f * sigma * sigma;

  // Estimate the blur radius based on sigma since sigma controls the "width" of
  // the bell curve. For example, for sigma = 3, the width of the bell curve is
  int blurRadius = (int)ceil(2.0f * sigma);

  assert(blurRadius <= MaxBlurRadius);

  eastl::vector<float> weights;
  weights.resize(2 * blurRadius + 1);

  float weightSum = 0.0f;

  for (int i = -blurRadius; i <= blurRadius; ++i) {
    float x = (float)i;

    weights[i + blurRadius] = expf(-x * x / twoSigma2);

    weightSum += weights[i + blurRadius];
  }

  // Divide by the sum so all the weights add up to 1.0.
  for (int i = 0; i < weights.size(); ++i) {
    weights[i] /= weightSum;
  }

  return weights;
}

ID3D12Resource* Ssao::NormalMap() { return mNormalMap.Get(); }

ID3D12Resource* Ssao::AmbientMap() { return mAmbientMap0.Get(); }

CD3DX12_CPU_DESCRIPTOR_HANDLE Ssao::NormalMapRtv() const {
  return mhNormalMapCpuRtv;
}

CD3DX12_GPU_DESCRIPTOR_HANDLE Ssao::NormalMapSrv() const {
  return mhNormalMapGpuSrv;
}

CD3DX12_GPU_DESCRIPTOR_HANDLE Ssao::AmbientMapSrv() const {
  return mhAmbientMap0GpuSrv;
}

void Ssao::BuildDescriptors(ID3D12Resource* depthStencilBuffer,
    CD3DX12_CPU_DESCRIPTOR_HANDLE hCpuSrv,
    CD3DX12_GPU_DESCRIPTOR_HANDLE hGpuSrv,
    CD3DX12_CPU_DESCRIPTOR_HANDLE hCpuRtv, UINT cbvSrvUavDescriptorSize,
    UINT rtvDescriptorSize) {
  // Save references to the descriptors.  The Ssao reserves heap space
  // for 5 contiguous Srvs.

  mhAmbientMap0CpuSrv = hCpuSrv;
  mhAmbientMap1CpuSrv = hCpuSrv.Offset(1, cbvSrvUavDescriptorSize);
  mhNormalMapCpuSrv = hCpuSrv.Offset(1, cbvSrvUavDescriptorSize);
  mhDepthMapCpuSrv = hCpuSrv.Offset(1, cbvSrvUavDescriptorSize);
  mhRandomVectorMapCpuSrv = hCpuSrv.Offset(1, cbvSrvUavDescriptorSize);

  mhAmbientMap0GpuSrv = hGpuSrv;
  mhAmbientMap1GpuSrv = hGpuSrv.Offset(1, cbvSrvUavDescriptorSize);
  mhNormalMapGpuSrv = hGpuSrv.Offset(1, cbvSrvUavDescriptorSize);
  mhDepthMapGpuSrv = hGpuSrv.Offset(1, cbvSrvUavDescriptorSize);
  mhRandomVectorMapGpuSrv = hGpuSrv.Offset(1, cbvSrvUavDescriptorSize);

  mhNormalMapCpuRtv = hCpuRtv;
  mhAmbientMap0CpuRtv = hCpuRtv.Offset(1, rtvDescriptorSize);
  mhAmbientMap1CpuRtv = hCpuRtv.Offset(1, rtvDescriptorSize);

  //  Create the descriptors
  RebuildDescriptors(depthStencilBuffer);
}

void Ssao::RebuildDescriptors(ID3D12Resource* depthStencilBuffer) {
  D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {};
  srvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING;
  srvDesc.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2D;
  srvDesc.Format = NormalMapFormat;
  srvDesc.Texture2D.MostDetailedMip = 0;
  srvDesc.Texture2D.MipLevels = 1;
  md3dDevice->CreateShaderResourceView(
      mNormalMap.Get(), &srvDesc, mhNormalMapCpuSrv);

  srvDesc.Format = DXGI_FORMAT_R24_UNORM_X8_TYPELESS;
  md3dDevice->CreateShaderResourceView(
      depthStencilBuffer, &srvDesc, mhDepthMapCpuSrv);

  srvDesc.Format = DXGI_FORMAT_R8G8B8A8_UNORM;
  md3dDevice->CreateShaderResourceView(
      mRandomVectorMap.Get(), &srvDesc, mhRandomVectorMapCpuSrv);

  srvDesc.Format = AmbientMapFormat;
  md3dDevice->CreateShaderResourceView(
      mAmbientMap0.Get(), &srvDesc, mhAmbientMap0CpuSrv);
  md3dDevice->CreateShaderResourceView(
      mAmbientMap1.Get(), &srvDesc, mhAmbientMap1CpuSrv);

  D3D12_RENDER_TARGET_VIEW_DESC rtvDesc = {};
  rtvDesc.ViewDimension = D3D12_RTV_DIMENSION_TEXTURE2D;
  rtvDesc.Format = NormalMapFormat;
  rtvDesc.Texture2D.MipSlice = 0;
  rtvDesc.Texture2D.PlaneSlice = 0;
  md3dDevice->CreateRenderTargetView(
      mNormalMap.Get(), &rtvDesc, mhNormalMapCpuRtv);

  rtvDesc.Format = AmbientMapFormat;
  md3dDevice->CreateRenderTargetView(
      mAmbientMap0.Get(), &rtvDesc, mhAmbientMap0CpuRtv);
  md3dDevice->CreateRenderTargetView(
      mAmbientMap1.Get(), &rtvDesc, mhAmbientMap1CpuRtv);
}

void Ssao::SetPSOs(
    ID3D12PipelineState* ssaoPso, ID3D12PipelineState* ssaoBlurPso) {
  mSsaoPso = ssaoPso;
  mBlurPso = ssaoBlurPso;
}

void Ssao::OnResize(UINT newWidth, UINT newHeight) {
  if (mRenderTargetWidth != newWidth || mRenderTargetHeight != newHeight) {
    mRenderTargetWidth = newWidth;
    mRenderTargetHeight = newHeight;

    // We render to ambient map at half the resolution.
    mViewport.TopLeftX = 0.0f;
    mViewport.TopLeftY = 0.0f;
    mViewport.Width = mRenderTargetWidth / 2.0f;
    mViewport.Height = mRenderTargetHeight / 2.0f;
    mViewport.MinDepth = 0.0f;
    mViewport.MaxDepth = 1.0f;

    mScissorRect = {
        0, 0, (int)mRenderTargetWidth / 2, (int)mRenderTargetHeight / 2};

    BuildResources();
  }
}

void Ssao::ComputeSsao(ID3D12GraphicsCommandList* cmdList,
    FrameResource* currFrame, int blurCount) {
  cmdList->RSSetViewports(1, &mViewport);
  cmdList->RSSetScissorRects(1, &mScissorRect);

  // We compute the initial SSAO to AmbientMap0.

  // Change to RENDER_TARGET.
  auto to_target = CD3DX12_RESOURCE_BARRIER::Transition(mAmbientMap0.Get(),
      D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_STATE_RENDER_TARGET);
  cmdList->ResourceBarrier(1, &to_target);

  float clearValue[] = {1.0f, 1.0f, 1.0f, 1.0f};
  cmdList->ClearRenderTargetView(mhAmbientMap0CpuRtv, clearValue, 0, nullptr);

  // Specify the buffers we are going to render to.
  cmdList->OMSetRenderTargets(1, &mhAmbientMap0CpuRtv, true, nullptr);

  // Bind the constant buffer for this pass.
  auto ssaoCBAddress = currFrame->ssao_cb->Resource()->GetGPUVirtualAddress();
  cmdList->SetGraphicsRootConstantBufferView(0, ssaoCBAddress);
  cmdList->SetGraphicsRoot32BitConstant(1, 0, 0);

  // Bind the normal and depth maps.
  cmdList->SetGraphicsRootDescriptorTable(2, mhNormalMapGpuSrv);

  // Bind the random vector map.
  cmdList->SetGraphicsRootDescriptorTable(3, mhRandomVectorMapGpuSrv);

  cmdList->SetPipelineState(mSsaoPso);

  // Draw fullscreen quad.
  cmdList->IASetVertexBuffers(0, 0, nullptr);
  cmdList->IASetIndexBuffer(nullptr);
  cmdList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
  cmdList->DrawInstanced(6, 1, 0, 0);

  // Change back to GENERIC_READ so we can read the texture in a shader.
  auto to_read = CD3DX12_RESOURCE_BARRIER::Transition(mAmbientMap0.Get(),
      D3D12_RESOURCE_STATE_RENDER_TARGET, D3D12_RESOURCE_STATE_GENERIC_READ);
  cmdList->ResourceBarrier(1, &to_read);

  BlurAmbientMap(cmdList, currFrame, blurCount);
}

void Ssao::BlurAmbientMap(ID3D12GraphicsCommandList* cmdList,
    FrameResource* currFrame, int blurCount) {
  cmdList->SetPipelineState(mBlurPso);

  auto ssaoCBAddress = currFrame->ssao_cb->Resource()->GetGPUVirtualAddress();
  cmdList->SetGraphicsRootConstantBufferView(0, ssaoCBAddress);

  for (int i = 0; i < blurCount; ++i) {
    BlurAmbientMap(cmdList, true);
    BlurAmbientMap(cmdList, false);
  }
}

void Ssao::BlurAmbientMap(ID3D12GraphicsCommandList* cmdList, bool horzBlur) {
  ID3D12Resource* output = nullptr;
  CD3DX12_GPU_DESCRIPTOR_HANDLE inputSrv;
  CD3DX12_CPU_DESCRIPTOR_HANDLE outputRtv;

  // Ping-pong the two ambient map textures as we apply
  // horizontal and vertical blur passes.
  if (horzBlur == true) {
    output = mAmbientMap1.Get();
    inputSrv = mhAmbientMap0GpuSrv;
    outputRtv = mhAmbientMap1CpuRtv;
    cmdList->SetGraphicsRoot32BitConstant(1, 1, 0);
  } else {
    output = mAmbientMap0.Get();
    inputSrv = mhAmbientMap1GpuSrv;
    outputRtv = mhAmbientMap0CpuRtv;
    cmdList->SetGraphicsRoot32BitConstant(1, 0, 0);
  }
  auto to_target = CD3DX12_RESOURCE_BARRIER::Transition(output,
      D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_STATE_RENDER_TARGET);
  cmdList->ResourceBarrier(1, &to_target);

  float clearValue[] = {1.0f, 1.0f, 1.0f, 1.0f};
  cmdList->ClearRenderTargetView(outputRtv, clearValue, 0, nullptr);

  cmdList->OMSetRenderTargets(1, &outputRtv, true, nullptr);

  // Normal/depth map still bound.

  // Bind the normal and depth maps.
  cmdList->SetGraphicsRootDescriptorTable(2, mhNormalMapGpuSrv);

  // Bind the input ambient map to second texture table.
  cmdList->SetGraphicsRootDescriptorTable(3, inputSrv);

  // Draw fullscreen quad.
  cmdList->IASetVertexBuffers(0, 0, nullptr);
  cmdList->IASetIndexBuffer(nullptr);
  cmdList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
  cmdList->DrawInstanced(6, 1, 0, 0);

  auto to_read = CD3DX12_RESOURCE_BARRIER::Transition(output,
      D3D12_RESOURCE_STATE_RENDER_TARGET, D3D12_RESOURCE_STATE_GENERIC_READ);
  cmdList->ResourceBarrier(1, &to_read);
}

void Ssao::BuildResources() {
  // Free the old resources if they exist.
  mNormalMap = nullptr;
  mAmbientMap0 = nullptr;
  mAmbientMap1 = nullptr;

  D3D12_RESOURCE_DESC texDesc;
  ZeroMemory(&texDesc, sizeof(D3D12_RESOURCE_DESC));
  texDesc.Dimension = D3D12_RESOURCE_DIMENSION_TEXTURE2D;
  texDesc.Alignment = 0;
  texDesc.Width = mRenderTargetWidth;
  texDesc.Height = mRenderTargetHeight;
  texDesc.DepthOrArraySize = 1;
  texDesc.MipLevels = 1;
  texDesc.Format = Ssao::NormalMapFormat;
  texDesc.SampleDesc.Count = 1;
  texDesc.SampleDesc.Quality = 0;
  texDesc.Layout = D3D12_TEXTURE_LAYOUT_UNKNOWN;
  texDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_RENDER_TARGET;

  float normalClearColor[] = {0.0f, 0.0f, 1.0f, 0.0f};
  CD3DX12_CLEAR_VALUE optClear(NormalMapFormat, normalClearColor);
  auto pdefault = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
  HASSERT(md3dDevice->CreateCommittedResource(&pdefault,
      D3D12_HEAP_FLAG_NONE, &texDesc, D3D12_RESOURCE_STATE_COMMON, &optClear,
      IID_PPV_ARGS(&mNormalMap)));

  // Ambient occlusion maps are at half resolution.
  texDesc.Width = mRenderTargetWidth / 2;
  texDesc.Height = mRenderTargetHeight / 2;
  texDesc.Format = Ssao::AmbientMapFormat;

  float ambientClearColor[] = {1.0f, 1.0f, 1.0f, 1.0f};
  optClear = CD3DX12_CLEAR_VALUE(AmbientMapFormat, ambientClearColor);

  HASSERT(md3dDevice->CreateCommittedResource(&pdefault,
      D3D12_HEAP_FLAG_NONE, &texDesc, D3D12_RESOURCE_STATE_GENERIC_READ,
      &optClear, IID_PPV_ARGS(&mAmbientMap0)));

  HASSERT(md3dDevice->CreateCommittedResource(&pdefault,
      D3D12_HEAP_FLAG_NONE, &texDesc, D3D12_RESOURCE_STATE_GENERIC_READ,
      &optClear, IID_PPV_ARGS(&mAmbientMap1)));
}

void Ssao::BuildRandomVectorTexture(ID3D12GraphicsCommandList* cmdList) {
  D3D12_RESOURCE_DESC texDesc;
  ZeroMemory(&texDesc, sizeof(D3D12_RESOURCE_DESC));
  texDesc.Dimension = D3D12_RESOURCE_DIMENSION_TEXTURE2D;
  texDesc.Alignment = 0;
  texDesc.Width = 256;
  texDesc.Height = 256;
  texDesc.DepthOrArraySize = 1;
  texDesc.MipLevels = 1;
  texDesc.Format = DXGI_FORMAT_R8G8B8A8_UNORM;
  texDesc.SampleDesc.Count = 1;
  texDesc.SampleDesc.Quality = 0;
  texDesc.Layout = D3D12_TEXTURE_LAYOUT_UNKNOWN;
  texDesc.Flags = D3D12_RESOURCE_FLAG_NONE;

  auto pdefault = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
  HASSERT(md3dDevice->CreateCommittedResource(&pdefault,
      D3D12_HEAP_FLAG_NONE, &texDesc, D3D12_RESOURCE_STATE_GENERIC_READ,
      nullptr, IID_PPV_ARGS(&mRandomVectorMap)));

  //
  // In order to copy CPU memory data into our default buffer, we need to create
  // an intermediate upload heap.
  //

  const UINT num2DSubresources = texDesc.DepthOrArraySize * texDesc.MipLevels;
  const UINT64 uploadBufferSize =
      GetRequiredIntermediateSize(mRandomVectorMap.Get(), 0, num2DSubresources);

  auto upload = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
  auto buffer = CD3DX12_RESOURCE_DESC::Buffer(uploadBufferSize);
  HASSERT(md3dDevice->CreateCommittedResource(&upload,
      D3D12_HEAP_FLAG_NONE, &buffer, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr,
      IID_PPV_ARGS(mRandomVectorMapUploadBuffer.GetAddressOf())));

  XMCOLOR initData[256 * 256];
  for (int i = 0; i < 256; ++i) {
    for (int j = 0; j < 256; ++j) {
      // Random vector in [0,1].  We will decompress in shader to [-1,1].
      XMFLOAT3 v(imath::RandF(), imath::RandF(), imath::RandF());

      initData[i * 256 + j] = XMCOLOR(v.x, v.y, v.z, 0.0f);
    }
  }

  D3D12_SUBRESOURCE_DATA subResourceData = {};
  subResourceData.pData = initData;
  subResourceData.RowPitch = 256 * sizeof(XMCOLOR);
  subResourceData.SlicePitch = subResourceData.RowPitch * 256;

  //
  // Schedule to copy the data to the default resource, and change states.
  // Note that mCurrSol is put in the GENERIC_READ state so it can be
  // read by a shader.
  //

  auto to_dest = CD3DX12_RESOURCE_BARRIER::Transition(mRandomVectorMap.Get(),
      D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_STATE_COPY_DEST);
  cmdList->ResourceBarrier(1, &to_dest);
  UpdateSubresources(cmdList, mRandomVectorMap.Get(),
      mRandomVectorMapUploadBuffer.Get(), 0, 0, num2DSubresources,
      &subResourceData);
  auto to_read = CD3DX12_RESOURCE_BARRIER::Transition(mRandomVectorMap.Get(),
      D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_GENERIC_READ);
  cmdList->ResourceBarrier(1, &to_read);
}

void Ssao::BuildOffsetVectors() {
  // Start with 14 uniformly distributed vectors.  We choose the 8 corners of
  // the cube and the 6 center points along each cube face.  We always alternate
  // the points on opposites sides of the cubes.  This way we still get the
  // vectors spread out even if we choose to use less than 14 samples.

  // 8 cube corners
  mOffsets[0] = XMFLOAT4(+1.0f, +1.0f, +1.0f, 0.0f);
  mOffsets[1] = XMFLOAT4(-1.0f, -1.0f, -1.0f, 0.0f);

  mOffsets[2] = XMFLOAT4(-1.0f, +1.0f, +1.0f, 0.0f);
  mOffsets[3] = XMFLOAT4(+1.0f, -1.0f, -1.0f, 0.0f);

  mOffsets[4] = XMFLOAT4(+1.0f, +1.0f, -1.0f, 0.0f);
  mOffsets[5] = XMFLOAT4(-1.0f, -1.0f, +1.0f, 0.0f);

  mOffsets[6] = XMFLOAT4(-1.0f, +1.0f, -1.0f, 0.0f);
  mOffsets[7] = XMFLOAT4(+1.0f, -1.0f, +1.0f, 0.0f);

  // 6 centers of cube faces
  mOffsets[8] = XMFLOAT4(-1.0f, 0.0f, 0.0f, 0.0f);
  mOffsets[9] = XMFLOAT4(+1.0f, 0.0f, 0.0f, 0.0f);

  mOffsets[10] = XMFLOAT4(0.0f, -1.0f, 0.0f, 0.0f);
  mOffsets[11] = XMFLOAT4(0.0f, +1.0f, 0.0f, 0.0f);

  mOffsets[12] = XMFLOAT4(0.0f, 0.0f, -1.0f, 0.0f);
  mOffsets[13] = XMFLOAT4(0.0f, 0.0f, +1.0f, 0.0f);

  for (int i = 0; i < 14; ++i) {
    // Create random lengths in [0.25, 1.0].
    float s = imath::RandF(0.25f, 1.0f);

    XMVECTOR v = s * XMVector4Normalize(XMLoadFloat4(&mOffsets[i]));

    XMStoreFloat4(&mOffsets[i], v);
  }
}

} // namespace ifire